import torch
from utils.log_helper import log_init
from tqdm import tqdm

if __name__ == '__main__':
    # Build a logger named "test_code" configured at INFO level, then emit the
    # same message at DEBUG and at INFO so the effect of the configured level
    # can be observed in the output.
    # NOTE(review): this presumes log_init applies `level` as a threshold, so
    # the DEBUG line is dropped and only the INFO line is printed — confirm
    # against utils.log_helper.log_init.
    logger = log_init("test_code", level='INFO')
    logger.debug('123')
    logger.info('123')
